#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2020/09/09 17:00
# @Author  : Wan Diwen
# @FileName: split_dataset.py
import os
import random
import argparse

# Split a segmentation dataset laid out as <root>/images and <root>/labels
# (files named by integer id, e.g. 123.tif / 123.png) into train/val id lists,
# written one id per line to <root>/<prefix>train.txt, <root>/<prefix>val.txt
# and <root>/all.txt.


def parse_args(argv=None):
    """Parse command-line options (``argv`` defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser('Split dataset')
    parser.add_argument('--root', default=os.path.expanduser('~/data/seg_naic/round2'),
                        help='The root path of dataset')
    parser.add_argument('--train-ratio', type=float, default=0.9,
                        help='Fraction of samples assigned to the train split')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed used for the shuffle')
    parser.add_argument('--prefix', default='',
                        help='Prefix prepended to train.txt / val.txt')
    return parser.parse_args(argv)


def _list_ids(directory):
    """Return the set of integer ids (file-name stems) of the visible entries in *directory*.

    Hidden entries (names starting with '.') are skipped so that files such as
    ``.DS_Store`` do not break parsing. Any other non-integer stem raises ValueError.
    """
    return {
        int(os.path.splitext(fn)[0])
        for fn in os.listdir(directory)
        if not fn.startswith('.')
    }


def _collect_ids(img_dir, ann_dir):
    """Return the sorted ids present in both *img_dir* and *ann_dir*.

    Raises:
        ValueError: if the two directories do not contain the same set of ids.
    """
    img_ids = _list_ids(img_dir)
    ann_ids = _list_ids(ann_dir)
    if img_ids != ann_ids:
        missing_ann = sorted(img_ids - ann_ids)[:10]
        missing_img = sorted(ann_ids - img_ids)[:10]
        raise ValueError(
            f'images and labels do not match: ids without label {missing_ann}, '
            f'ids without image {missing_img} (first 10 shown)')
    return sorted(img_ids)


def split_ids(ids, train_ratio=0.9, seed=0):
    """Shuffle *ids* deterministically and split them into (train, val), each sorted.

    The ids are sorted before shuffling so the result depends only on the set of
    ids and *seed*, not on the (filesystem-dependent) order of ``os.listdir``.
    ``random.Random(seed).shuffle`` yields the same permutation as
    ``random.seed(seed); random.shuffle(...)`` without touching global state.

    Raises:
        ValueError: if *train_ratio* is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f'train_ratio must be in [0, 1], got {train_ratio}')
    shuffled = sorted(ids)
    random.Random(seed).shuffle(shuffled)
    num_train = int(len(shuffled) * train_ratio)
    return sorted(shuffled[:num_train]), sorted(shuffled[num_train:])


def _write_ids(path, ids):
    """Write *ids* to *path*, one per line."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f'{index}\n' for index in ids)


def main(argv=None):
    """Entry point: read ids, split them and write the list files under ``--root``."""
    args = parse_args(argv)
    root = args.root
    ids = _collect_ids(os.path.join(root, 'images'), os.path.join(root, 'labels'))
    print(f'There are {len(ids)} images')

    train_list, val_list = split_ids(ids, args.train_ratio, args.seed)
    print(f'After Split, there are {len(train_list)} images for train, {len(val_list)} for val')

    _write_ids(os.path.join(root, args.prefix + 'train.txt'), train_list)
    _write_ids(os.path.join(root, args.prefix + 'val.txt'), val_list)
    _write_ids(os.path.join(root, 'all.txt'), ids)
    print(f'Successful! {args.prefix}train.txt, {args.prefix}val.txt and all.txt '
          f'are saved in root {root}')


if __name__ == '__main__':
    main()
